from __future__ import annotations
import time
import re
import optuna
import ConfigSpace as CS

from ConfigSpace.exceptions import (
    InactiveHyperparameterSetError,
    ActiveHyperparameterNotSetError,
)
from ConfigSpace.hyperparameters import (
    CategoricalHyperparameter,
    UniformFloatHyperparameter, UniformIntegerHyperparameter,
    FloatHyperparameter, IntegerHyperparameter,
    Constant,
)

from objective import Objective
from loggers import ExperimentLogger
# ---------------------------------------------------------------------
# Utility: Placeholder default values & parse HP name from exception text
# ---------------------------------------------------------------------
def _fill_default_for(hp):
    """Return a value that can be assigned to ``hp`` when ConfigSpace requires one.

    Resolution order:
      1. ``Constant``            -> its fixed ``value``
      2. any HP with a non-None ``default_value`` -> that default
      3. categorical             -> the first choice
      4. float / integer         -> the lower bound, cast to ``float`` / ``int``

    Raises:
        NotImplementedError: if ``hp`` matches none of the rules above.
    """
    if isinstance(hp, Constant):
        return hp.value

    declared_default = getattr(hp, "default_value", None)
    if declared_default is not None:
        return declared_default

    # Checked in order; the first matching type wins.
    placeholder_rules = (
        (CategoricalHyperparameter, lambda p: p.choices[0]),
        ((UniformFloatHyperparameter, FloatHyperparameter), lambda p: float(p.lower)),
        ((UniformIntegerHyperparameter, IntegerHyperparameter), lambda p: int(p.lower)),
    )
    for hp_types, make_placeholder in placeholder_rules:
        if isinstance(hp, hp_types):
            return make_placeholder(hp)

    raise NotImplementedError(f"Unknown HP type: {type(hp).__name__} ({hp.name})")


def _extract_hp_name_from_exception_text(msg: str, cs: CS.ConfigurationSpace) -> str | None:
    """
    Parse the problematic hyperparameter name from exception text:
    - The text usually contains a line: `<name>, Type: ...`
    - First, take the token before the comma; if it's in names_set, return it
    - Otherwise, use a loose regex to capture possible names and verify
    """
    names_set = {hp.name for hp in cs.values()}
    for line in reversed((msg or "").splitlines()):
        line = line.strip()
        if not line:
            continue
        token = line.split(",", 1)[0].strip()
        if token in names_set:
            return token
        m = re.search(r"([A-Za-z0-9:_\-\.\[\]]+)", line)
        if m and m.group(1) in names_set:
            return m.group(1)
    return None


def _name_from_active_error(e: ActiveHyperparameterNotSetError, cs: CS.ConfigurationSpace) -> str | None:
    name = getattr(getattr(e, "hp", None), "name", None)
    if name:
        return name
    return _extract_hp_name_from_exception_text(str(e), cs)


def _name_from_inactive_error(e: InactiveHyperparameterSetError, cs: CS.ConfigurationSpace) -> str | None:
    name = getattr(getattr(e, "hp", None), "name", None)
    if name:
        return name
    return _extract_hp_name_from_exception_text(str(e), cs)


# ---------------------------------------------------------------------
# Canonicalization: validate with ConfigSpace, drop inactive values, and fill missing active ones
# ---------------------------------------------------------------------
def _canonicalize(cs: CS.ConfigurationSpace, values: dict, max_iters: int | None = None) -> dict:
    """
    Iteratively repair ``values`` into a legal configuration of ``cs``.

    Each round builds ``CS.Configuration(cs, values=values)`` and reacts to the error it raises:
      - InactiveHyperparameterSetError: the HP is inactive but assigned -> drop its key.
      - ActiveHyperparameterNotSetError: the HP is active but missing -> fill it with
        ``_fill_default_for`` (its default or a placeholder).

    Args:
        cs: Configuration space to validate against.
        values: Partial or over-complete assignment. It is copied, never mutated.
        max_iters: Maximum repair rounds; defaults to ``max(12, 3 * number_of_HPs)``.

    Returns:
        ``dict(CS.Configuration(...))`` of the repaired assignment.

    Raises:
        Whatever ``CS.Configuration`` raises when ``values`` cannot be repaired: any error
        other than the two handled above, an active HP that cannot be identified or filled,
        or no convergence within ``max_iters``.
    """
    values = dict(values)
    if max_iters is None:
        max_iters = max(12, 3 * len(cs.values()))

    for _ in range(max_iters):
        try:
            return dict(CS.Configuration(cs, values=values))
        except InactiveHyperparameterSetError as e:
            name = _name_from_inactive_error(e, cs)
            if name:
                values.pop(name, None)
                continue
            # The illegal key cannot be identified: restart from an empty assignment and let
            # the active-branch refill defaults. This discards every caller-supplied value,
            # so surface it rather than doing it silently.
            import warnings
            warnings.warn(
                "_canonicalize: could not identify the inactive hyperparameter from "
                f"{e!r}; discarding all values and falling back to defaults.",
                RuntimeWarning,
                stacklevel=2,
            )
            values = {}
        except ActiveHyperparameterNotSetError as e:
            name = _name_from_active_error(e, cs)
            if name:
                values[name] = _fill_default_for(cs[name])
                continue
            # The missing key cannot be identified: fill every unassigned constant or
            # single-choice categorical (these have only one possible value).
            filled_any = False
            for hp in cs.values():
                if hp.name in values:
                    continue
                if isinstance(hp, Constant) or (
                    isinstance(hp, CategoricalHyperparameter) and len(hp.choices) == 1
                ):
                    values[hp.name] = _fill_default_for(hp)
                    filled_any = True
            if not filled_any:
                # Nothing changed, so every further round would fail identically: stop now.
                raise

    # Not converged: this raises the underlying error for debugging. If the last repair
    # happened to succeed, return the same normalized dict as the in-loop success path.
    return dict(CS.Configuration(cs, values=values))


# ---------------------------------------------------------------------
# Optuna: Map from ConfigSpace to suggest_*, canonicalize after each assignment
# ---------------------------------------------------------------------
def _suggest_from_configspace(trial: optuna.Trial, cs: CS.ConfigurationSpace):
    """
    Sample a legal configuration of ``cs`` through Optuna's ``trial.suggest_*`` API.

    Sampling proceeds as follows:
      1. The assignment starts as the canonicalized default (``_canonicalize(cs, {})``).
      2. Hyperparameters are visited in ``cs.values()`` order. After each sampled value the
         assignment is re-canonicalized: children activated by the new value get a
         placeholder, and children deactivated by it are dropped.
      3. An HP that is absent from the canonical assignment when it is reached is inactive
         under the values already chosen. It is skipped, so Optuna never records a parameter
         that does not appear in the returned configuration.

    NOTE(review): relies on ``cs.values()`` yielding parents before their conditional
    children — confirm for the installed ConfigSpace version.
    NOTE(review): forbidden clauses are not handled. A partial assignment that violates one
    makes ``_canonicalize`` raise.

    Returns:
        dict mapping hyperparameter names to values; always a legal configuration of ``cs``.

    Raises:
        NotImplementedError: for hyperparameter types with no ``suggest_*`` mapping.
    """
    cfg = _canonicalize(cs, {})
    for hp in cs.values():  # ConfigSpace 1.x API
        name = hp.name

        if name not in cfg:
            continue  # inactive given the values chosen so far
        if isinstance(hp, Constant):
            continue  # canonicalization already assigned its fixed value

        if isinstance(hp, CategoricalHyperparameter):
            val = trial.suggest_categorical(name, list(hp.choices))
        elif isinstance(hp, (UniformFloatHyperparameter, FloatHyperparameter)):
            log = bool(getattr(hp, "log", False) or getattr(hp, "log_scale", False))
            val = trial.suggest_float(name, float(hp.lower), float(hp.upper), log=log)
        elif isinstance(hp, (UniformIntegerHyperparameter, IntegerHyperparameter)):
            log = bool(getattr(hp, "log", False) or getattr(hp, "log_scale", False))
            val = trial.suggest_int(name, int(hp.lower), int(hp.upper), log=log)
        else:
            raise NotImplementedError(f"Unsupported HP type for Optuna-Random: {type(hp).__name__} ({name})")

        cfg[name] = val
        cfg = _canonicalize(cs, cfg)

    return cfg  # Always legal


# ---------------------------------------------------------------------
# Random Search (Optuna.RandomSampler)
# ---------------------------------------------------------------------
def run_random_optuna(*,
                      seed: int,
                      bench: str,
                      cs: CS.ConfigurationSpace,
                      obj: Objective,
                      budget_n: int,
                      logger: ExperimentLogger,
                      method_name: str = "RandomSearch-Optuna"):
    """
    Run random search over ``cs`` with ``optuna.samplers.RandomSampler(seed=seed)``.

    Each of the ``budget_n`` trials does the following:
      - samples a legal configuration via ``_suggest_from_configspace``;
      - calls ``obj.evaluate(cfg)``, which must return ``(value, sim_time)``;
      - minimizes ``value``.

    After every trial that produced a value, one record is passed to ``logger.log`` with:
      - seed, method, bench;
      - n_eval: number of trials with a value so far (1-based);
      - sim_time: the ``sim_time`` returned by ``obj.evaluate`` for this trial (a proxy
        runtime; whether it is per-trial or cumulative is up to ``Objective``);
      - elapsed_time: wall-clock seconds spent inside this trial's ``obj.evaluate`` call;
      - curr_score / best_score: ``1 - value`` for this trial / the best value so far;
      - config: the evaluated configuration.

    NOTE(review): the ``1 - value`` conversion assumes ``obj.evaluate`` returns 1 - score
    (e.g. an error rate) — confirm.
    """
    study = optuna.create_study(
        direction="minimize",
        sampler=optuna.samplers.RandomSampler(seed=seed),
    )

    best = float("inf")
    # Running count of trials with a value. Equivalent to scanning ``study.trials`` for
    # this fresh, sequential study, but avoids deep-copying every trial on each callback
    # (which made the run quadratic in ``budget_n``).
    n_eval = 0

    def objective(trial: optuna.Trial):
        cfg = _suggest_from_configspace(trial, cs)
        t0 = time.perf_counter()
        curr, sim_t = obj.evaluate(cfg)
        elapsed = time.perf_counter() - t0
        trial.set_user_attr("sim_time", sim_t)
        trial.set_user_attr("elapsed_time", elapsed)
        trial.set_user_attr("config", cfg)
        return curr

    def cb(study: optuna.Study, trial: optuna.FrozenTrial):
        nonlocal best, n_eval
        if trial.value is None:  # e.g. pruned trials carry no value
            return
        n_eval += 1
        curr = trial.value
        best = min(best, curr)
        logger.log(dict(
            seed=seed, method=method_name, bench=bench,
            n_eval=n_eval,
            sim_time=trial.user_attrs.get("sim_time", 0.0),
            elapsed_time=trial.user_attrs.get("elapsed_time", 0.0),
            best_score=1-best, curr_score=1-curr,
            config=trial.user_attrs.get("config", {}),
        ))

    study.optimize(objective, n_trials=budget_n, callbacks=[cb], show_progress_bar=False)
